import argparse
import json
import os
import os.path as osp
from functools import partial
from pathlib import Path
from typing import Dict, List
from multiprocessing import Process, Queue
import tqdm
import numpy as np
from mmengine import list_dir_or_file
from transformers import AutoTokenizer
from typing import Callable
from xpuyu.datasets.ftdp import qwen as qwen_role_cfg
from xpuyu.datasets.ftdp import ftdp_tokenize


def chatml_format(
    processed_data,
    tokenizer,
    role_cfg,
    encode_json=True,
):
    """Tokenize one processed sample and sign-encode its loss mask.

    ``ftdp_tokenize`` returns ``input_ids`` and ``labels``. Every position
    whose label differs from its input id is stored as the negated id.
    Presumably those are the positions excluded from the loss; confirm
    against ``ftdp_tokenize``.

    NOTE(review): ``-0 == 0``, so a masked token with id 0 cannot be told
    apart from a trained one in this encoding.

    Args:
        processed_data: One decoded JSONL record, passed to ``ftdp_tokenize``.
        tokenizer: HF tokenizer forwarded to ``ftdp_tokenize``.
        role_cfg: Role/template config forwarded to ``ftdp_tokenize``.
        encode_json: If true, return a newline-terminated JSON ``bytes``
            line ``{"tokens": [...]}``. Otherwise return the raw list.

    Returns:
        ``(payload, n_tokens)``. ``payload`` is either the encoded JSON line
        or the list of signed token ids.
    """
    tokenized = ftdp_tokenize(tokenizer, processed_data, role_cfg)
    signed_ids = []
    for tok, lab in zip(tokenized['input_ids'], tokenized['labels']):
        signed_ids.append(tok if lab == tok else -tok)
    n_tokens = len(signed_ids)
    if not encode_json:
        return signed_ids, n_tokens
    payload = json.dumps({'tokens': signed_ids}) + '\n'
    return payload.encode(), n_tokens


def tokenize_worker(
    tokenize_fun: Callable,
    data_queue: Queue,
    out_queue: Queue,
):
    """Worker loop that tokenizes chunks of ``(index, sample)`` pairs.

    Each chunk read from ``data_queue`` becomes one list of
    ``[index, tokenize_fun(sample)]`` entries on ``out_queue``. A ``None``
    chunk stops the loop. The worker then forwards a single ``None`` to
    ``out_queue`` so the consumer can count finished workers.
    """
    # iter(callable, sentinel) keeps calling get() until it returns None.
    for data_chunk in iter(data_queue.get, None):
        out_queue.put([[idx, tokenize_fun(item)] for idx, item in data_chunk])
    out_queue.put(None)


def chunk_data_to_queue(data_queue: Queue, data: List[Dict], chunk_size: int,
                        nproc):
    """Split ``data`` into lists of ``chunk_size`` items and enqueue them.

    The final chunk may be shorter. After all chunks, one ``None`` stop
    sentinel is enqueued per worker.

    Args:
        data_queue: Queue the workers read from.
        data: Any iterable of items. The caller in this file passes an
            ``enumerate`` object, so the items are ``(index, sample)`` pairs.
        chunk_size: Maximum number of items per enqueued chunk. Must be >= 1.
        nproc: Number of workers, i.e. number of ``None`` sentinels to send.

    Raises:
        ValueError: If ``chunk_size`` is smaller than 1. Previously this
            silently sent the whole dataset as a single chunk.
    """
    if chunk_size < 1:
        raise ValueError(f'chunk_size must be >= 1, got {chunk_size}')

    chunk_data = []
    for item in data:
        chunk_data.append(item)
        if len(chunk_data) == chunk_size:
            data_queue.put(chunk_data)
            chunk_data = []
    if chunk_data:
        data_queue.put(chunk_data)

    # One sentinel per worker so every worker's get() loop terminates.
    for _ in range(nproc):
        data_queue.put(None)


def track_progress(tokenize_fun_p, dataset, nproc, task_num, chunksize,
                   description):
    """Apply ``tokenize_fun_p`` to ``dataset`` in ``nproc`` worker processes.

    The dataset is enqueued as ``(index, sample)`` chunks before the workers
    start. Results are gathered until every worker has sent its ``None``
    sentinel, then re-ordered by index.

    Args:
        tokenize_fun_p: Callable applied to each sample inside a worker. It
            is passed to ``multiprocessing.Process``, so it must be picklable
            under non-``fork`` start methods.
        dataset: Iterable of samples.
        nproc: Number of worker processes.
        task_num: Expected number of samples. Only used to size the
            progress bar.
        chunksize: Samples per queue message.
        description: Progress-bar label.

    Returns:
        Iterator over the per-sample results, in the original dataset order.

    Raises:
        RuntimeError: If a worker exits abnormally, e.g. because
            ``tokenize_fun_p`` raised or the process was killed. Such a
            worker never sends its sentinel, and previously the loop below
            waited for it forever.
    """
    import queue  # only for ``queue.Empty`` raised by a timed ``get``

    processes = []
    data_queue = Queue()
    output_queue = Queue()
    bar = tqdm.tqdm(total=task_num, desc=description)
    try:
        chunk_data_to_queue(data_queue, enumerate(dataset), chunksize, nproc)
        for _ in range(nproc):
            process = Process(
                target=tokenize_worker,
                args=(tokenize_fun_p, data_queue, output_queue))
            process.start()
            processes.append(process)

        results = []
        finished_process = 0
        while finished_process < nproc:
            try:
                chunk_results = output_queue.get(timeout=1.0)
            except queue.Empty:
                # Nothing arrived. If a worker died it will never send its
                # sentinel, so fail loudly instead of blocking forever.
                if any(p.exitcode not in (None, 0) for p in processes):
                    for p in processes:
                        p.terminate()
                    for p in processes:
                        p.join()
                    raise RuntimeError(
                        'a tokenize worker exited abnormally; exit codes: '
                        f'{[p.exitcode for p in processes]}')
                continue
            if chunk_results is None:
                finished_process += 1
                continue
            results.extend(chunk_results)
            bar.update(len(chunk_results))
    finally:
        bar.close()

    # Every worker has delivered its final sentinel, so its queue output is
    # fully flushed and joining cannot block on an undrained pipe.
    for process in processes:
        process.join()

    return map(lambda x: x[1], sorted(results, key=lambda x: x[0]))


def write_bin_meta_bin(path, dataset_name, filename, samples,
                       valid_interval=100, min_valid_samples=500):
    """Write tokenized samples to train/valid ``.bin`` files plus meta index.

    Every sample is appended to ``{path}/train/cn/{dataset_name}/
    {filename}.bin``. Every ``valid_interval``-th sample is also copied into
    the matching ``valid`` file, so valid samples are a subset of the
    training samples.

    Each ``.bin.meta`` file holds ``np.save`` output of one
    ``(byte_offset, token_num)`` row per sample in the corresponding
    ``.bin``. The train meta is written only after the train ``.bin`` is
    complete.

    Args:
        path: Root output directory.
        dataset_name: Sub-directory under ``train/cn`` and ``valid/cn``.
        filename: Output stem; ``.bin`` / ``.bin.meta`` are appended.
        samples: Iterable of ``(line_bytes, token_num)`` pairs.
        valid_interval: Copy every N-th sample (1-based count) into the
            valid split. Must be positive.
        min_valid_samples: Keep the valid split only if it has at least this
            many samples. Otherwise its ``.bin`` is deleted and no meta is
            written. The value should exceed the data-parallel size.

    Returns:
        ``(train_tokens, valid_tokens, train_samples, valid_samples)``.
    """
    train_dir = Path(path) / 'train' / 'cn' / dataset_name
    valid_dir = Path(path) / 'valid' / 'cn' / dataset_name
    train_dir.mkdir(exist_ok=True, parents=True)
    valid_dir.mkdir(exist_ok=True, parents=True)
    print(train_dir)
    print(valid_dir)

    valid_f_path = valid_dir / f'{filename}.bin'
    train_tokens = valid_tokens = 0
    train_samples = valid_samples = 0
    last_train_position = last_valid_position = 0
    train_meta = []
    valid_meta = []
    with open(train_dir / f'{filename}.bin', 'wb') as train_f, \
            open(valid_f_path, 'wb') as valid_f:
        for line, token_num in samples:
            train_tokens += token_num
            train_f.write(line)
            train_meta.append((last_train_position, token_num))
            last_train_position += len(line)
            train_samples += 1
            if train_samples % valid_interval == 0:
                valid_tokens += token_num
                valid_f.write(line)
                valid_meta.append((last_valid_position, token_num))
                last_valid_position += len(line)
                valid_samples += 1

    with open(train_dir / f'{filename}.bin.meta', 'wb') as meta_f:
        np.save(meta_f, train_meta)

    # A tiny valid split is dropped. Keep it only when it has at least
    # `min_valid_samples` samples (the old `> 500` check also dropped exactly
    # 500, contradicting its own "less than 500" message).
    if valid_samples >= min_valid_samples:
        with open(valid_dir / f'{filename}.bin.meta', 'wb') as meta_f:
            np.save(meta_f, valid_meta)
    else:
        print(f'{valid_f_path} is removed because the number of',
              f'`valid_samples`({valid_samples}) is less than '
              f'{min_valid_samples}')
        os.remove(valid_f_path)
    return train_tokens, valid_tokens, train_samples, valid_samples


def tokenize_and_save(tokenizer, processed_dir, tokenized_dir, nproc=96,
                      chunksize=96):
    """Tokenize every processed JSONL file and write train/valid shards.

    Only files whose joined path contains ``/processed/`` are handled. The
    dataset name is the first path component of each file relative to
    ``processed_dir``. Output goes to
    ``{tokenized_dir}/chatml_llamav13_32k``. Per-file and total token and
    sample counts are printed.

    Args:
        tokenizer: HF tokenizer passed to ``chatml_format``.
        processed_dir: Root directory scanned recursively for input files.
        tokenized_dir: Root directory for tokenized output.
        nproc: Number of tokenizer worker processes per file.
        chunksize: Samples per worker queue message.

    Raises:
        ValueError: If a selected file is not a ``.jsonl`` file.
    """
    tokenized_save_dir = osp.join(tokenized_dir, 'chatml_llamav13_32k')
    # The tokenize function is identical for every file; build it once.
    tokenize_fun = partial(
        chatml_format,
        tokenizer=tokenizer,
        role_cfg=qwen_role_cfg)
    all_train_tokens = 0
    all_valid_tokens = 0
    all_train_samples = 0
    all_valid_samples = 0

    for filename in list_dir_or_file(
            processed_dir, recursive=True, list_dir=False):
        file_path = osp.join(processed_dir, filename)
        if '/processed/' not in file_path:
            continue
        if '.jsonl' not in filename:
            raise ValueError(f'expected a .jsonl file, got {file_path}')

        # dataset name such as char_x10_chat_format
        dataset_name = filename.split(os.sep)[0]
        stem = osp.splitext(osp.basename(filename))[0]

        # `write_bin_meta_bin` writes the train `.bin.meta` only after the
        # train `.bin` is complete. Requiring both files keeps a partially
        # written `.bin` from an interrupted run from being skipped forever.
        train_f = osp.join(tokenized_save_dir, 'train', 'cn', dataset_name,
                           f'{stem}.bin')
        if osp.isfile(train_f) and osp.isfile(f'{train_f}.meta'):
            print(f'{train_f} already exists, skip it')
            continue

        # JSONL is UTF-8 by definition; don't depend on the locale encoding.
        # Blank lines are not valid JSON records and would crash json.loads.
        with open(file_path, encoding='utf-8') as f:
            lines = [line for line in f if line.strip()]
        task_num = len(lines)
        dataset = map(json.loads, lines)

        samples = list(track_progress(
            tokenize_fun,
            dataset,
            nproc=nproc,
            task_num=task_num,
            chunksize=chunksize,
            description=f'{osp.basename(file_path)}...'))
        del lines

        train_tokens, valid_tokens, train_samples, valid_samples = (
            write_bin_meta_bin(
                path=tokenized_save_dir,
                dataset_name=dataset_name,
                samples=samples,
                filename=stem))

        print(f'train_tokens {train_tokens}', flush=True)
        print(f'train_samples {train_samples}')
        print(f'valid tokens {valid_tokens}')
        print(f'valid_samples {valid_samples}')
        all_train_tokens += train_tokens
        all_valid_tokens += valid_tokens
        all_train_samples += train_samples
        all_valid_samples += valid_samples

    print(f'all train tokens {all_train_tokens}')
    print(f'all train samples {all_train_samples}')
    print(f'all valid tokens {all_valid_tokens}')
    print(f'all valid samples {all_valid_samples}')

def parse_args():
    """Parse the command-line options of the tokenization script.

    All three paths are mandatory. Previously a missing option became
    ``None`` and only failed later with an unrelated error from path
    handling or tokenizer loading.

    Returns:
        ``argparse.Namespace`` with ``processed_dir``, ``tokenized_dir`` and
        ``tokenizer_path``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--processed-dir',
        required=True,
        help='The folder to save untokenized data.')
    parser.add_argument(
        '--tokenized-dir',
        required=True,
        help='The folder to save tokenized data.')
    parser.add_argument(
        '--tokenizer-path',
        required=True,
        help='The path to the hf tokenizer.')
    args = parser.parse_args()
    return args


def main():
    """Script entry point: load the HF tokenizer, then tokenize all files."""
    cli_args = parse_args()
    hf_tokenizer = AutoTokenizer.from_pretrained(
        cli_args.tokenizer_path,
        padding_side='right',
        trust_remote_code=True,
    )
    tokenize_and_save(
        hf_tokenizer, cli_args.processed_dir, cli_args.tokenized_dir)


# Run the tokenization pipeline only when executed as a script, not on import.
if __name__ == '__main__':
    main()
